import torch
import torch.nn as nn
from data_utils import *
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from tqdm import tqdm
from IPython import embed
import numpy as np
from numpy import linalg as LA
import argparse, os, sys, subprocess
import copy

from deepset import *
from cdm import *

def model_eval_versus(model, loader, batch_size, criterion, cutoff):
    """Evaluate ``model`` on head-to-head matchups without updating weights.

    Only batches whose leading dimension equals ``batch_size`` are scored; any
    other-sized (e.g. trailing partial) batch is skipped.

    Args:
        model: called as ``model(team1, team2)``; its output is treated as the
            outcome with respect to team 1.
        loader: iterable yielding ``(team1, team2, outcome)`` tensor triples.
        batch_size: expected batch size; batches of any other size are ignored.
        criterion: loss callable invoked as ``criterion(out, outcome)``.
        cutoff: threshold separating a team-1 win from a loss. It is applied to
            both the predictions and the targets, so it works for 0/1 labels
            (cutoff 0.5) as well as regressed score differences (cutoff 0.0).

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the scored batches.

    Raises:
        ValueError: if the loader produced no batch of exactly ``batch_size``.
    """
    total_loss = 0.0
    accuracy = 0.0
    processed_no = 0

    with torch.no_grad():
        for team1, team2, outcome in loader:
            if team1.shape[0] != batch_size:
                continue

            model.zero_grad()
            out = model(team1, team2)  # outcome wrt team 1

            # `outcome` is already a tensor; wrapping it in torch.tensor() only
            # makes a redundant copy and triggers a UserWarning.
            loss = criterion(out, outcome).float()
            total_loss += loss.item()

            # Threshold both sides. For 0/1 labels this equals comparing with
            # the raw labels; for regressed score differences it compares the
            # predicted winner with the actual winner (not a bool with a margin).
            predicted_wins = out.numpy() > cutoff
            actual_wins = outcome.numpy() > cutoff
            accuracy += np.average(predicted_wins == actual_wins)
            processed_no += 1

    if processed_no == 0:
        raise ValueError(
            "no batch of size {} was produced by the loader; "
            "cannot compute loss/accuracy".format(batch_size)
        )
    return total_loss / processed_no, accuracy / processed_no

def compute_l1(l1_loss, model, factor):
    """Return ``factor`` times the summed L1 penalty over all model parameters.

    ``l1_loss`` is evaluated for each parameter against an all-zeros target of
    the same shape. With ``nn.L1Loss()`` defaults (mean reduction), each term is
    that parameter's mean absolute value.
    """
    penalty = sum(
        l1_loss(weight, target=torch.zeros_like(weight))
        for weight in model.parameters()
    )
    return factor * penalty

def columnize(res, loader_type):
    """Flatten batched predictions into CSV-formatted text rows.

    Each element of ``res`` is a ``(team1, team2, score, pred)`` tuple of
    tensors. Row ``i`` of a batch becomes
    ``"<loader_type>,<team1[i] ids> <team2[i] ids>,<int(score[i][0])>,<pred[i][0]>"``,
    where the ids of each team are joined with single spaces.
    """
    def _space_join(values):
        return " ".join(map(str, values))

    lines = []
    for t1_batch, t2_batch, score_batch, pred_batch in res:
        t1_batch, t2_batch = t1_batch.numpy(), t2_batch.numpy()
        score_batch, pred_batch = score_batch.numpy(), pred_batch.numpy()
        for idx, t1_row in enumerate(t1_batch):
            players = _space_join(t1_row) + " " + _space_join(t2_batch[idx])
            score_text = str(int(score_batch[idx][0]))
            pred_text = str(pred_batch[idx][0])
            lines.append(",".join([loader_type, players, score_text, pred_text]))
    return lines

def eval_all(model, loader, batch_size, criterion):
    """Run ``model`` over ``loader`` without gradients and keep every prediction.

    Only batches whose leading dimension equals ``batch_size`` are scored;
    other-sized batches are skipped.

    Args:
        model: called as ``model(team1, team2)``.
        loader: iterable yielding ``(team1, team2, outcome)`` tensor triples.
        batch_size: expected batch size; batches of any other size are ignored.
        criterion: loss callable invoked as ``criterion(out, outcome)``.

    Returns:
        ``(result, mean_loss)``. ``result`` holds one
        ``(team1, team2, outcome, out)`` tuple per scored batch, and
        ``mean_loss`` is the loss averaged over those batches.

    Raises:
        ValueError: if the loader produced no batch of exactly ``batch_size``.
    """
    result = []
    total_loss = 0.0

    with torch.no_grad():
        for team1, team2, outcome in loader:
            if team1.shape[0] != batch_size:
                continue

            model.zero_grad()
            out = model(team1, team2)

            result.append((team1, team2, outcome, out))

            # `outcome` is already a tensor; torch.tensor(outcome) would only
            # copy it and emit a UserWarning.
            loss = criterion(out, outcome).float()
            total_loss += loss.item()

    if not result:
        raise ValueError(
            "no batch of size {} was produced by the loader; "
            "cannot compute the mean loss".format(batch_size)
        )
    return result, total_loss / len(result)

def save_predictions(best_model, loaders, batch_size, loss_function, filename="cdm_nba_prediction.csv"):
    """Evaluate ``best_model`` on the train/val/test loaders and dump rows to CSV.

    ``loaders`` must unpack into exactly three loaders, in the order train, val,
    test. For each split, the mean loss from ``eval_all`` is printed and the
    predictions are converted with ``columnize`` using the split name as the
    first column. All rows are written to ``filename``, separated by newlines
    and with no trailing newline.
    """
    train_loader, val_loader, test_loader = loaders
    named_splits = (
        ("train", train_loader),
        ("val", val_loader),
        ("test", test_loader),
    )

    all_perf = []
    for split_name, split_loader in named_splits:
        split_res, split_loss = eval_all(best_model, split_loader, batch_size, loss_function)
        print(split_loss)
        all_perf.extend(columnize(split_res, split_name))

    with open(filename, "w") as out_file:
        out_file.write("\n".join(all_perf))


def _build_model(args):
    """Construct the network selected by ``args.model``.

    ``"fhoi"``, ``"cdm"`` and ``"linear"`` select FHoi, CDM and LR respectively.
    Any other value falls back to DeepSet. These classes come from the star
    imports at the top of the file.
    """
    if args.model == "fhoi":
        return FHoi(args.num_players)
    if args.model == "cdm":
        return CDM(args.num_players, args.embed_size)
    if args.model == "linear":
        return LR(args.num_players)
    return DeepSet(args.num_players, args.embed_size, linear_dim=args.linear_dim)


def train(args):
    """Train the selected model, keep the best-validation weights, and test them.

    The training file is split into train/validation subsets
    (``DataSet(..., train_split=0.85)``). Validation accuracy is checked every
    ``args.eval_iter`` epochs, and the weights with the highest validation
    accuracy are evaluated once on the test file. If no validation pass ran
    (``num_epochs == 0``), the final/loaded weights are used instead.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Returns:
        Test-set accuracy of the selected weights.
    """
    cutoff = 0.0 if args.regress else 0.5

    np.random.seed(0)

    # split into 85% train, 15% val
    train_set = DataSet(args.train_path, train_split=0.85)
    test_set = DataSet(args.test_path)

    train_indices, val_indices = train_set.get_split_indices()
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, sampler=train_sampler)
    val_loader = DataLoader(train_set, batch_size=args.batch_size, sampler=valid_sampler)
    test_loader = DataLoader(test_set, batch_size=args.batch_size)

    model = _build_model(args)
    print(model)

    if args.load_path:
        model.load_state_dict(torch.load(args.load_path))

    loss_function = nn.MSELoss() if args.regress else nn.BCELoss()
    l1_loss = nn.L1Loss()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.learn_rate, weight_decay=args.l2_regularization)

    model_eval = model_eval_versus
    avg_loss, acc = model_eval(model, train_loader, args.batch_size, loss_function, cutoff)
    print("start train loss {}".format(avg_loss))
    print("start train accuracy is {} \n".format(acc))
    avg_loss, acc = model_eval(model, val_loader, args.batch_size, loss_function, cutoff)
    print("start validation loss {}".format(avg_loss))
    print("start validation accuracy is {} \n".format(acc))

    best_acc = float("-inf")
    best_model_weights = None
    for epoch in range(args.num_epochs):
        report_train = epoch % args.print_train == 0
        total_loss = 0.0
        accuracy = 0.0
        processed_no = 0
        for team1, team2, outcome in train_loader:
            if team1.shape[0] != args.batch_size:
                continue
            model.zero_grad()
            out = model(team1, team2)  # outcome wrt team 1

            # `outcome` is already a tensor; no torch.tensor() copy needed.
            loss = loss_function(out, outcome).float()
            if args.l1_regularization > 0:
                loss = loss + compute_l1(l1_loss, model, factor=args.l1_regularization)
            loss.backward()
            optimizer.step()

            if report_train:
                # Threshold predictions and targets alike so accuracy is also
                # meaningful when regressing score differences.
                predicted_wins = out.detach().numpy() > cutoff
                actual_wins = outcome.numpy() > cutoff
                accuracy += np.average(predicted_wins == actual_wins)
                total_loss += loss.item()
                processed_no += 1

        if report_train and args.verbose and processed_no > 0:
            print("train epoch {} loss {}".format(epoch, total_loss / processed_no))
            print("train accuracy is {} \n".format(accuracy / processed_no))

        if epoch % args.eval_iter == 0:
            avg_loss, acc = model_eval(model, val_loader, args.batch_size, loss_function, cutoff)
            print("validation epoch {} loss {}".format(epoch, avg_loss))
            print("validation accuracy is {} \n".format(acc))

            if acc > best_acc:
                best_model_weights = copy.deepcopy(model.state_dict())
                best_acc = acc

    if best_model_weights is None:
        # No validation pass ran (num_epochs == 0); keep the current weights
        # instead of failing with a NameError.
        best_model_weights = copy.deepcopy(model.state_dict())

    best_model = _build_model(args)
    best_model.load_state_dict(best_model_weights)
    test_loss, test_acc = model_eval(best_model, test_loader, args.batch_size, loss_function, cutoff)

    # To dump per-game predictions for every split, call:
    # save_predictions(best_model, [train_loader, val_loader, test_loader],
    #                  args.batch_size, loss_function, filename="cdm_nba_prediction.csv")

    if args.save_path:
        torch.save(best_model_weights, args.save_path)

    print(args.train_path)
    # Index of the last training epoch (what the loop variable ended at).
    print("test epoch {} loss {}".format(args.num_epochs - 1, test_loss))
    print("test accuracy is {} \n".format(test_acc))

    return test_acc

if __name__ == '__main__':
    cli = argparse.ArgumentParser()

    # (long flag, short alias, type, default, help text). Registration order is
    # kept so print(args) lists options in the same order as before.
    option_table = [
        ('--num_epochs', '-epochs', int, 100, None),
        ('--eval_iter', '-eval', int, 20, "Evaluate accuracy every eval_iter"),
        ('--print_train', '-print_train', int, 20, "how often print training loss"),
        ('--batch_size', '-batch_size', int, 50, "Batch size"),
        ('--learn_rate', '-lr', float, 1e-3, "Learning rate"),
        ('--l2_regularization', '-l2_lambda', float, 1e-3, "L2 regularization"),
        ('--l1_regularization', '-l1_lambda', float, 0.0, "L1 regularization"),
        ('--embed_size', '-embed_size', int, 2, "Embedding size"),
        ('--linear_dim', '-linear_dim', int, 10, "Linear dim"),
        ('--num_players', '-num_players', int, 725, "Number of players"),
        ('--train_path', '-train_path', str, "data/starter_nba_train.txt", "File path"),
        ('--test_path', '-test_path', str, "data/starter_nba_test.txt", "File path"),
        ('--save_path', '-save_path', str, "", "Save path"),
        ('--load_path', '-load_path', str, "", "Load path"),
        ('--regress', '-regress', str, "False", "regress score difference"),
        ('--verbose', '-verbose', str, "True", "verbosity"),
        ('--data_format', '-data_format', str, "matchup", "Dataset format"),
        ('--model', '-model', str, "fhoi", "model"),
    ]
    for long_flag, short_flag, value_type, default_value, help_text in option_table:
        cli.add_argument(long_flag, short_flag, type=value_type, default=default_value, help=help_text)

    args = cli.parse_args()
    # These options arrive as strings; only the exact text "True" enables them.
    for flag_name in ('regress', 'verbose'):
        setattr(args, flag_name, getattr(args, flag_name) == "True")
    print(args)

    train(args)

